// Copyright Epic Games, Inc. All Rights Reserved.

#include <zencore/crypto.h>
#include <zencore/intmath.h>
#include <zencore/scopeguard.h>
#include <zencore/testing.h>

#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#ifndef ZEN_USE_OPENSSL
#	if ZEN_PLATFORM_WINDOWS
#		define ZEN_USE_OPENSSL 0
#	else
#		define ZEN_USE_OPENSSL 1
#	endif
#endif

ZEN_THIRD_PARTY_INCLUDES_START
#include <fmt/format.h>

#if ZEN_USE_OPENSSL
#	include <openssl/conf.h>
#	include <openssl/err.h>
#	include <openssl/evp.h>
#else
#	include <zencore/windows.h>
#	include <bcrypt.h>
#	define NT_SUCCESS(Status)	(((NTSTATUS)(Status)) >= 0)
#	define STATUS_UNSUCCESSFUL ((NTSTATUS)0xC0000001L)
#endif
ZEN_THIRD_PARTY_INCLUDES_END

namespace zen {

using namespace std::literals;

namespace crypto {
	/// Direction of a crypto::Transform call.
	/// On the OpenSSL path the enumerator is cast to int and passed as the `enc` argument of
	/// EVP_CipherInit_ex, so the numeric values are significant and are spelled out explicitly.
	enum class TransformMode : uint32_t
	{
		Decrypt = 0,
		Encrypt = 1,
	};

#if ZEN_USE_OPENSSL

	/// Scoped owner of an EVP_CIPHER_CTX: allocates it on construction and frees it on destruction.
	/// Copying is disabled because two copies would both free the same context.
	class EvpContext
	{
	public:
		EvpContext() : m_Ctx(EVP_CIPHER_CTX_new()) {}
		~EvpContext() { EVP_CIPHER_CTX_free(m_Ctx); }

		EvpContext(const EvpContext&) = delete;
		EvpContext& operator=(const EvpContext&) = delete;

		/// Raw handle for the EVP_* calls. May be nullptr if EVP_CIPHER_CTX_new failed, so
		/// callers should check before use.
		operator EVP_CIPHER_CTX*() { return m_Ctx; }

	private:
		EVP_CIPHER_CTX* const m_Ctx;
	};

	MemoryView Transform(TransformMode				 Mode,
						 MemoryView					 Key,
						 MemoryView					 IV,
						 MemoryView					 In,
						 MutableMemoryView			 Out,
						 std::optional<std::string>& Reason)
	{
		// OpenSSL implementation: single-shot AES-256-CBC over `In` into `Out` using EVP's default
		// block padding. Returns the written prefix of `Out`, or an empty view with `Reason` set.
		const EVP_CIPHER* Cipher = EVP_aes_256_cbc();

		ZEN_ASSERT(Cipher != nullptr);

		// EVP_CipherInit_ex reads a fixed number of key/IV bytes straight from the pointers it is
		// given, so undersized views would be over-read.
		if (Key.GetSize() != static_cast<uint64_t>(EVP_CIPHER_key_length(Cipher)))
		{
			Reason = "invalid key length"sv;

			return MemoryView();
		}

		if (IV.GetSize() < static_cast<uint64_t>(EVP_CIPHER_iv_length(Cipher)))
		{
			Reason = "block length is longer than the provided IV length"sv;

			return MemoryView();
		}

		const uint64_t BlockSize = static_cast<uint64_t>(EVP_CIPHER_block_size(Cipher));
		const uint64_t InSize	 = In.GetSize();

		// EVP_CipherUpdate takes an int length and the output may grow by up to one block.
		if (InSize > static_cast<uint64_t>((std::numeric_limits<int>::max)()) - BlockSize)
		{
			Reason = "input buffer is too large"sv;

			return MemoryView();
		}

		// OpenSSL is never told how large `Out` is, so validate it here. Encryption always appends
		// padding, producing the input rounded up to the next whole block; a single-shot decryption
		// never writes more bytes than it is given.
		const uint64_t RequiredOutSize = Mode == TransformMode::Encrypt ? (InSize / BlockSize + 1) * BlockSize : InSize;
		if (Out.GetSize() < RequiredOutSize)
		{
			Reason = "invalid output buffer size"sv;

			return MemoryView();
		}

		EvpContext Ctx;
		if (static_cast<EVP_CIPHER_CTX*>(Ctx) == nullptr)
		{
			Reason = "failed to allocate cipher context"sv;

			return MemoryView();
		}

		// TransformMode's numeric value is used directly as EVP's `enc` flag (Decrypt = 0, Encrypt = 1).
		int Err = EVP_CipherInit_ex(Ctx,
									Cipher,
									nullptr,
									reinterpret_cast<const unsigned char*>(Key.GetData()),
									reinterpret_cast<const unsigned char*>(IV.GetData()),
									static_cast<int>(Mode));

		if (Err != 1)
		{
			Reason = fmt::format("failed to initialize cipher, error code '{}'", Err);

			return MemoryView();
		}

		int UpdateBytes = 0;

		Err = EVP_CipherUpdate(Ctx,
							   reinterpret_cast<unsigned char*>(Out.GetData()),
							   &UpdateBytes,
							   reinterpret_cast<const unsigned char*>(In.GetData()),
							   static_cast<int>(InSize));

		if (Err != 1)
		{
			Reason = fmt::format("update crypto transform failed, error code '{}'", Err);

			return MemoryView();
		}

		// The final (padding) block is written immediately after what Update produced.
		MutableMemoryView Remaining	 = Out.RightChop(UpdateBytes);
		int				  FinalBytes = 0;

		Err = EVP_CipherFinal_ex(Ctx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &FinalBytes);

		if (Err != 1)
		{
			Reason = fmt::format("finalize crypto transform failed, error code '{}'", Err);

			return MemoryView();
		}

		return Out.Left(UpdateBytes + FinalBytes);
	}
#else
	MemoryView Transform(TransformMode				 Mode,
						 MemoryView					 Key,
						 MemoryView					 IV,
						 MemoryView					 In,
						 MutableMemoryView			 Out,
						 std::optional<std::string>& Reason)
	{
		// Windows CNG implementation: single-shot AES-CBC over `In` into `Out` with
		// BCRYPT_BLOCK_PADDING. Returns the written prefix of `Out`, or an empty view with
		// `Reason` set on failure.

		// NTSTATUS is a signed 32-bit value and failure codes have the high bit set; format the raw
		// bit pattern so fmt does not render e.g. 0xC0000001 as a negative hex number.
		auto StatusError = [](std::string_view Api, NTSTATUS Status) {
			return fmt::format("Error 0x{:08x} returned by {}"sv, static_cast<uint32_t>(Status), Api);
		};

		// Open an algorithm handle.
		BCRYPT_ALG_HANDLE hAesAlg = nullptr;
		NTSTATUS		  Status  = BCryptOpenAlgorithmProvider(&hAesAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
		if (!NT_SUCCESS(Status))
		{
			Reason = StatusError("BCryptOpenAlgorithmProvider"sv, Status);
			return {};
		}

		auto AlgGuard = MakeGuard([hAesAlg] { BCryptCloseAlgorithmProvider(hAesAlg, 0); });

		ULONG cbData = 0;

		ULONG cbBlockLen = 0;
		Status = BCryptGetProperty(hAesAlg, BCRYPT_BLOCK_LENGTH, reinterpret_cast<PUCHAR>(&cbBlockLen), sizeof(cbBlockLen), &cbData, 0);
		if (!NT_SUCCESS(Status))
		{
			Reason = StatusError("BCryptGetProperty"sv, Status);
			return {};
		}

		if (cbBlockLen > IV.GetSize())
		{
			Reason = "block length is longer than the provided IV length"sv;

			return {};
		}

		// BCryptEncrypt/BCryptDecrypt write back into the IV buffer, so operate on a private copy.
		AesIV128Bit MutableIV = AesIV128Bit::FromMemoryView(IV);

		Status = BCryptSetProperty(hAesAlg,
								   BCRYPT_CHAINING_MODE,
								   reinterpret_cast<PUCHAR>(const_cast<wchar_t*>(BCRYPT_CHAIN_MODE_CBC)),
								   sizeof(BCRYPT_CHAIN_MODE_CBC),
								   0);
		if (!NT_SUCCESS(Status))
		{
			Reason = StatusError("BCryptSetProperty"sv, Status);
			return {};
		}

		ULONG cbKeyObject = 0;
		Status = BCryptGetProperty(hAesAlg, BCRYPT_OBJECT_LENGTH, reinterpret_cast<PUCHAR>(&cbKeyObject), sizeof(cbKeyObject), &cbData, 0);
		if (!NT_SUCCESS(Status))
		{
			Reason = StatusError("BCryptGetProperty"sv, Status);
			return {};
		}

		PUCHAR pbKeyObject = static_cast<PUCHAR>(Memory::Alloc(cbKeyObject));
		if (pbKeyObject == nullptr)
		{
			Reason = "memory allocation failed"sv;
			return {};
		}

		// The key object may hold key material; scrub it before returning the memory.
		auto KeyObjectGuard = MakeGuard([pbKeyObject, cbKeyObject] {
			SecureZeroMemory(pbKeyObject, cbKeyObject);
			Memory::Free(pbKeyObject);
		});

		BCRYPT_KEY_HANDLE hKey = nullptr;

		Status = BCryptGenerateSymmetricKey(hAesAlg,
											&hKey,
											pbKeyObject,
											cbKeyObject,
											static_cast<PUCHAR>(const_cast<void*>(Key.GetData())),
											static_cast<ULONG>(Key.GetSize()),
											/* flags */ 0);
		if (!NT_SUCCESS(Status))
		{
			Reason = StatusError("BCryptGenerateSymmetricKey"sv, Status);
			return {};
		}

		// Declared after KeyObjectGuard, so the key is destroyed before its object buffer is freed.
		auto KeyGuard = MakeGuard([hKey] { BCryptDestroyKey(hKey); });

		// BCrypt takes ULONG sizes; refuse inputs that would be silently truncated.
		constexpr uint64_t MaxBCryptSize = (std::numeric_limits<ULONG>::max)();
		if (In.GetSize() > MaxBCryptSize)
		{
			Reason = "input buffer is too large"sv;
			return {};
		}

		const ULONG cbInput = static_cast<ULONG>(In.GetSize());
		// A larger Out is fine: BCrypt only needs to know that at least cbOutput bytes are writable.
		const ULONG cbOutput = Out.GetSize() > MaxBCryptSize ? static_cast<ULONG>(MaxBCryptSize) : static_cast<ULONG>(Out.GetSize());

		// BCryptEncrypt and BCryptDecrypt share one signature, so both directions use the same flow.
		const bool			   bEncrypt	   = Mode == TransformMode::Encrypt;
		const auto			   CryptFn	   = bEncrypt ? &BCryptEncrypt : &BCryptDecrypt;
		const std::string_view CryptFnName = bEncrypt ? "BCryptEncrypt"sv : "BCryptDecrypt"sv;

		PUCHAR const pbInput = static_cast<PUCHAR>(const_cast<void*>(In.GetData()));
		PUCHAR const pbIV	 = static_cast<PUCHAR>(const_cast<void*>(MutableIV.GetView().GetData()));

		// First call, with no output buffer, only reports how many bytes the transform will produce.
		ULONG cbResult = 0;
		Status = CryptFn(hKey, pbInput, cbInput, nullptr, pbIV, cbBlockLen, nullptr, 0, &cbResult, BCRYPT_BLOCK_PADDING);
		if (!NT_SUCCESS(Status))
		{
			Reason = StatusError(CryptFnName, Status);
			return {};
		}

		if (Out.GetSize() < cbResult)
		{
			Reason = "invalid output buffer size"sv;
			return {};
		}

		Status = CryptFn(hKey,
						 pbInput,
						 cbInput,
						 nullptr,
						 pbIV,
						 cbBlockLen,
						 static_cast<PUCHAR>(Out.GetData()),
						 cbOutput,
						 &cbResult,
						 BCRYPT_BLOCK_PADDING);
		if (!NT_SUCCESS(Status))
		{
			Reason = StatusError(CryptFnName, Status);
			return {};
		}

		return Out.Left(cbResult);
	}
#endif

	/// Returns true when both the key and the IV report IsValid(). Otherwise it stores a short
	/// description of the first failing input (key checked before IV) in Reason and returns false.
	bool ValidateKeyAndIV(const AesKey256Bit& Key, const AesIV128Bit& IV, std::optional<std::string>& Reason)
	{
		std::string_view Failure;

		if (!Key.IsValid())
		{
			Failure = "invalid key"sv;
		}
		else if (!IV.IsValid())
		{
			Failure = "invalid initialization vector"sv;
		}

		if (Failure.empty())
		{
			return true;
		}

		Reason = Failure;
		return false;
	}

}  // namespace crypto

MemoryView
Aes::Encrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason)
{
	// Only well-formed key material reaches the cipher; ValidateKeyAndIV fills in Reason otherwise.
	if (crypto::ValidateKeyAndIV(Key, IV, Reason))
	{
		return crypto::Transform(crypto::TransformMode::Encrypt, Key.GetView(), IV.GetView(), In, Out, Reason);
	}

	return MemoryView();
}

MemoryView
Aes::Decrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason)
{
	// Only well-formed key material reaches the cipher; ValidateKeyAndIV fills in Reason otherwise.
	if (crypto::ValidateKeyAndIV(Key, IV, Reason))
	{
		return crypto::Transform(crypto::TransformMode::Decrypt, Key.GetView(), IV.GetView(), In, Out, Reason);
	}

	return MemoryView();
}

#if ZEN_WITH_TESTS

void
crypto_forcelink()
{
	// Intentionally empty. Presumably referenced by the test harness so that this translation
	// unit (and its TEST_CASEs) is linked in — confirm against the caller.
}

TEST_CASE("crypto.bits")
{
	// 32 characters is exactly 256 bits: valid for CryptoBits<256>, the wrong size for CryptoBits<128>.
	constexpr std::string_view ThirtyTwoChars = "abcdefghijklmnopqrstuvxyz0123456"sv;
	// Far too short to fill 256 bits.
	constexpr std::string_view TooShort = "Addff"sv;

	CryptoBits<256> Value;

	// A default-constructed value is null and invalid, yet still reports its fixed size.
	CHECK(Value.IsNull());
	CHECK(!Value.IsValid());
	CHECK(Value.GetBitCount() == 256);
	CHECK(Value.GetSize() == 32);

	Value = CryptoBits<256>::FromString(TooShort);
	CHECK(!Value.IsValid());

	Value = CryptoBits<256>::FromString(ThirtyTwoChars);
	CHECK(Value.IsValid());

	auto Narrow = CryptoBits<128>::FromString(ThirtyTwoChars);
	CHECK(!Narrow.IsValid());
}

TEST_CASE("crypto.aes")
{
	SUBCASE("basic")
	{
		const uint8_t	   InitVector[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
		const AesKey256Bit Key			= AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv);
		const AesIV128Bit  IV			= AesIV128Bit::FromMemoryView(MakeMemoryView(InitVector));

		std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv;

		std::vector<uint8_t>	   EncryptionBuffer;
		std::vector<uint8_t>	   DecryptionBuffer;
		std::optional<std::string> Reason;

		EncryptionBuffer.resize(PlainText.size() + Aes::BlockSize);
		DecryptionBuffer.resize(PlainText.size() + Aes::BlockSize);

		MemoryView EncryptedView = Aes::Encrypt(Key, IV, MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer), Reason);

		// Encryption must succeed and pad the input up to the next whole block.
		CHECK(Reason.has_value() == false);
		CHECK(EncryptedView.GetSize() == (PlainText.size() / Aes::BlockSize + 1) * Aes::BlockSize);

		// The ciphertext must not simply echo the plaintext.
		const std::string_view CipherText(reinterpret_cast<const char*>(EncryptedView.GetData()), EncryptedView.GetSize());
		CHECK(CipherText.substr(0, PlainText.size()) != PlainText);

		MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedView, MakeMutableMemoryView(DecryptionBuffer), Reason);
		CHECK(Reason.has_value() == false);

		std::string_view EncryptedDecryptedText =
			std::string_view(reinterpret_cast<const char*>(DecryptedView.GetData()), DecryptedView.GetSize());

		CHECK(EncryptedDecryptedText == PlainText);
	}

	SUBCASE("invalid key")
	{
		const uint8_t	   InitVector[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
		const AesKey256Bit Key			= AesKey256Bit::FromString("Addff"sv);
		const AesIV128Bit  IV			= AesIV128Bit::FromMemoryView(MakeMemoryView(InitVector));

		std::string_view		   PlainText = "The quick brown fox jumps over the lazy dog"sv;
		std::vector<uint8_t>	   EncryptionBuffer(PlainText.size() + Aes::BlockSize);
		std::optional<std::string> Reason;

		MemoryView EncryptedView = Aes::Encrypt(Key, IV, MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer), Reason);

		// A wrongly sized key is rejected up front and the failure is reported through Reason.
		const std::string ReasonText = Reason.value_or(std::string());
		CHECK(EncryptedView.GetSize() == 0);
		CHECK(ReasonText == "invalid key"sv);
	}
}

#endif

}  // namespace zen
